import json
import argparse

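# Disabled snippet, kept for reference: removes an inherited RAY_ADDRESS
# environment variable and calls Ray's private `reset_ray_address` helper.
# NOTE(review): re-enabling it requires `import os` and `import ray`, neither of
# which is imported in the import block at the top of this script.
# if 'RAY_ADDRESS' in os.environ:
#     del os.environ['RAY_ADDRESS']

# ray._private.utils.reset_ray_address()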

from UnitCell_Environment.unitcell_environment.env.utils import load_compositions
from learn import training_task

if __name__ == "__main__":
    # Script entry point: parse the command line, read the JSON config files,
    # pass the composition config to `load_compositions`, then call
    # `training_task` with the assembled env/algo configs.

    cli = argparse.ArgumentParser(description='Optional app description')

    # One (flag, type, default, help) row per option, in the order they are
    # registered with the parser.
    # NOTE(review): --task_config is accepted but never read in this script.
    option_specs = (
        ('--num_workers', int, 0,
         'Number of workers collecting experience'),
        ("--algo_config", str, "config.json",
         "The config file for tunable parameters of the chosen RL method"),
        ("--task_config", str, "task_config.json",
         "The config file for the task parameters"),
        ("--env_config", str, "env_config.json",
         "The config file for the env parameters"),
        ("--comp_config", str, "comp_config.json",
         "the config file with compositions and structure sizes supported"),
        ("--checkpoint", str, None,
         "checkpoint to continue marl from"),
        ("--iterations", int, 25000,
         "The number of iterations to train the policy"),
    )
    for flag, kind, fallback, text in option_specs:
        cli.add_argument(flag, type=kind, default=fallback, help=text)

    args = cli.parse_args()

    # Algorithm settings come from JSON; the worker count is always taken from
    # the command line and overwrites whatever the file contained.
    with open(args.algo_config, 'r') as algo_file:
        algo_config = json.load(algo_file)
    algo_config["num_workers"] = args.num_workers

    # Environment settings come from JSON and are then extended with values
    # derived from the algorithm config plus one hard-coded entry.
    with open(args.env_config, "r") as env_file:
        env_config = json.load(env_file)
    env_config["gamma"] = algo_config.get("gamma", 0.99)
    uses_gpu = algo_config.get("num_gpus", 0) > 0
    env_config["gpu"] = uses_gpu
    env_config["agent_levels"] = ["2_stepsize_"]

    # The composition config is loaded and handed to `load_compositions`
    # before training is started.
    with open(args.comp_config, 'r') as comp_file:
        comp_config = json.load(comp_file)
    load_compositions(comp_config)

    # The result is kept in `checkpoint` but not used further in this script.
    checkpoint = training_task(
        env_config,
        algo_config,
        args.iterations,
        checkpoint=args.checkpoint,
    )
